Oxford Pets with Bounding Boxes¶
In [1]:
import os
import xml.etree.ElementTree as ET
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
class PetDetectionDataset(Dataset):
    def __init__(self, root, transform=None):
        """
        Custom dataset for Oxford-IIIT Pet detection.

        - Loads images from '<root>/oxford-iiit-pet/images/'
        - Reads bounding boxes from '<root>/oxford-iiit-pet/annotations/xmls/'
        - Uses the filename casing to derive a binary label: dog (0) or cat (1).
          Oxford-IIIT filenames are lowercase for dog breeds and capitalized for
          cat breeds, so `islower()` -> 0 (dog), else 1 (cat). This matches the
          `'Cat' if label == 1` checks used by the plotting cells below.

        Parameters
        ----------
        root : str
            Directory that contains the 'oxford-iiit-pet' folder.
        transform : callable, optional
            Transform applied to the loaded PIL image (e.g. Resize + ToTensor).
        """
        self.image_dir = os.path.join(root, "oxford-iiit-pet", "images")
        self.annotation_dir = os.path.join(root, "oxford-iiit-pet", "annotations", "xmls")
        self.transform = transform
        # Get list of valid files (only those with a corresponding XML file)
        self.image_files = []
        self.bboxes = []
        self.labels = []
        for xml_file in os.listdir(self.annotation_dir):
            if xml_file.endswith(".xml"):
                image_name = xml_file.replace(".xml", ".jpg")  # Image filename
                image_path = os.path.join(self.image_dir, image_name)
                xml_path = os.path.join(self.annotation_dir, xml_file)
                # Ensure image file exists
                if os.path.exists(image_path):
                    # Parse XML file to get bounding box; skip images without one
                    bbox = self.parse_xml(xml_path)
                    if bbox:
                        self.image_files.append(image_path)
                        self.bboxes.append(bbox)
                        # Breed name is everything before the trailing index,
                        # e.g. 'Abyssinian_12.jpg' -> 'Abyssinian'
                        breed_name = "_".join(image_name.split("_")[:-1])
                        # Dog if lowercase (0), Cat if capitalized (1)
                        label = 0 if breed_name.islower() else 1
                        self.labels.append(label)

    def parse_xml(self, xml_file):
        """Extract the first object's bounding box [xmin, ymin, xmax, ymax]
        (pixel coordinates) from the XML annotation, or None if no object."""
        tree = ET.parse(xml_file)
        root = tree.getroot()
        bbox = None
        for obj in root.findall("object"):
            bndbox = obj.find("bndbox")
            xmin = int(bndbox.find("xmin").text)
            ymin = int(bndbox.find("ymin").text)
            xmax = int(bndbox.find("xmax").text)
            ymax = int(bndbox.find("ymax").text)
            bbox = [xmin, ymin, xmax, ymax]  # Format: (xmin, ymin, xmax, ymax)
            break  # Only take the first object (each image should have one pet)
        return bbox

    def __len__(self):
        # Number of (image, bbox, label) triples collected in __init__
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return (image, bbox, label) where bbox is normalized to [0, 1]
        relative to the ORIGINAL image size, so it remains valid after the
        transform resizes the image."""
        image_path = self.image_files[idx]
        bbox = torch.tensor(self.bboxes[idx], dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        # Load the image
        image = Image.open(image_path).convert("RGB")
        original_w, original_h = image.size  # Get original image size
        # Apply transformations (resize to 224x224)
        if self.transform:
            image = self.transform(image)
        # Normalize bounding box coordinates relative to original image dimensions
        bbox[0] /= original_w  # Normalize xmin
        bbox[1] /= original_h  # Normalize ymin
        bbox[2] /= original_w  # Normalize xmax
        bbox[3] /= original_h  # Normalize ymax
        return image, bbox, label
In [2]:
# Define image transformations (resize so every sample batches to the same shape)
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])

# Create dataset
data_root = "./data"
dataset = PetDetectionDataset(root=data_root, transform=transform)

# Split into training and validation sets (80/20).
# Seed the split generator so the train/val partition is reproducible across
# kernel restarts — otherwise every run trains/evaluates on a different split.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)

# Create DataLoaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Check dataset size
print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}")
Train size: 2948, Validation size: 738
In [3]:
# Grab one batch from the training loader and visualize its first example.
images, bboxes, labels = next(iter(train_loader))
sample_img = images[0]
sample_bbox = bboxes[0]  # normalized (xmin, ymin, xmax, ymax)
sample_label = labels[0].item()

print(f"Label: {'Cat' if sample_label == 1 else 'Dog'}")
print(f"Bounding Box: {sample_bbox.numpy()}")

# CHW tensor -> HWC numpy array for matplotlib
img_np = sample_img.permute(1, 2, 0).numpy()

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(img_np)

# Scale normalized bbox coordinates up to the resized 224x224 image
xmin, ymin, xmax, ymax = (sample_bbox * 224).tolist()

# Draw the box from its top-left corner, width and height
rect = patches.Rectangle(
    (xmin, ymin), xmax - xmin, ymax - ymin,
    linewidth=2, edgecolor="r", facecolor="none",
)
ax.add_patch(rect)

ax.set_title("Sample Image with Computed Bounding Box")
ax.axis("off")
plt.show()
Label: Dog Bounding Box: [0.204 0.08408409 0.81 0.8918919 ]
In [4]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
def visualize_grid(dataloader, num_images=9):
    """
    Display a grid of images with bounding boxes and class labels.

    Parameters
    ----------
    dataloader : iterable yielding (images, bboxes, labels) batches, where
        bboxes are normalized (xmin, ymin, xmax, ymax) and label 1 = cat,
        0 = dog (as produced by PetDetectionDataset).
    num_images : int
        Number of samples to show. The grid shape is derived from this
        value (3x3 for the default of 9), so values other than 9 no longer
        overflow the previously hard-coded 3x3 axes array.
    """
    # Derive a near-square grid large enough to hold num_images panels.
    n_cols = max(1, round(num_images ** 0.5))
    n_rows = -(-num_images // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
    # plt.subplots returns a bare Axes (not an array) for a 1x1 grid.
    axes = axes.flatten() if num_images > 1 else [axes]
    count = 0
    for images, bboxes, labels in dataloader:
        for i in range(len(images)):
            if count >= num_images:
                break
            sample_img = images[i]
            sample_bbox = bboxes[i]  # Format: (xmin, ymin, xmax, ymax)
            sample_label = labels[i].item()
            # Convert tensor image back to numpy for visualization
            img_np = sample_img.permute(1, 2, 0).numpy()
            # Denormalize bbox coordinates (based on resized 224x224 image)
            xmin = sample_bbox[0] * 224
            ymin = sample_bbox[1] * 224
            xmax = sample_bbox[2] * 224
            ymax = sample_bbox[3] * 224
            # Compute width and height
            width = xmax - xmin
            height = ymax - ymin
            # Plot the image
            ax = axes[count]
            ax.imshow(img_np)
            ax.set_xticks([])
            ax.set_yticks([])
            # Draw bounding box
            rect = patches.Rectangle((xmin, ymin), width, height, linewidth=2, edgecolor="r", facecolor="none")
            ax.add_patch(rect)
            # Add label
            label_text = "Cat" if sample_label == 1 else "Dog"
            ax.text(xmin, ymin - 5, label_text, color="white", fontsize=12,
                    bbox=dict(facecolor="red", alpha=0.5, edgecolor="none"))
            count += 1
        if count >= num_images:
            break
    # Hide any unused panels (when num_images < n_rows * n_cols).
    for ax in axes[count:]:
        ax.axis("off")
    plt.tight_layout()
    plt.show()


# Call the function
visualize_grid(train_loader)
In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import torchvision.models as models
class PetClassifierAndBBox(pl.LightningModule):
    def __init__(self, lambda_bbox=5.0, lr=.001):
        """
        PyTorch Lightning module for pet classification and bounding box detection.

        - Uses EfficientNet-B0 (ImageNet weights) as a feature extractor.
        - Two heads on a shared 512-d embedding:
            - classification (2 logits; label convention from the dataset is
              dog = 0, cat = 1)
            - bounding box regression (4 values; targets are the dataset's
              normalized xmin, ymin, xmax, ymax)
        - Loss: cross-entropy (classification) + lambda_bbox * MSE (bounding box)

        Parameters
        ----------
        lambda_bbox : float
            Weight of the bbox MSE term in the combined loss.
        lr : float
            Adam learning rate.
        """
        super().__init__()
        self.save_hyperparameters()
        # Load pre-trained EfficientNet as feature extractor.
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision
        # releases in favor of `weights=...` — confirm the installed version.
        efficientnet = models.efficientnet_b0(pretrained=True)
        self.feature_extractor = efficientnet.features  # Remove classification head
        # Shared fully connected trunk: GAP -> flatten -> 1280 -> 512
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # Global Average Pooling
            nn.Flatten(),
            nn.Linear(1280, 512),  # EfficientNet-B0 has 1280 output features
            nn.ReLU()
        )
        # Classification head (binary: dog vs. cat).
        # NOTE(review): the two stacked Linear layers have no activation
        # between them, so they collapse to a single affine map. Inserting a
        # ReLU would shift the Sequential state-dict keys and break loading
        # of existing checkpoints, so it is only flagged here.
        self.classification_head = nn.Sequential(
            nn.Linear(512,512),
            nn.Linear(512, 2)  # Output 2 classes
        )
        # Bounding box head (regression: normalized xmin, ymin, xmax, ymax).
        # NOTE(review): same missing-activation remark as above.
        self.bbox_head = nn.Sequential(
            nn.Linear(512,512),
            nn.Linear(512, 4)  # Output 4 coordinates
        )
        # Loss weight for the bbox term, and learning rate for the optimizer
        self.lambda_bbox = lambda_bbox
        self.lr = lr

    def forward(self, x):
        """Forward pass: feature extractor -> shared FC -> both heads.
        Returns (class_logits [B, 2], bbox_preds [B, 4])."""
        features = self.feature_extractor(x)  # Extract convolutional features
        features = self.fc(features)  # Pool + project to the shared 512-d embedding
        class_logits = self.classification_head(features)  # Classification head
        bbox_preds = self.bbox_head(features)  # Bounding box head
        return class_logits, bbox_preds

    def training_step(self, batch, batch_idx):
        """Training step: compute the combined loss and log epoch metrics."""
        images, bboxes, labels = batch  # Unpack batch
        class_logits, bbox_preds = self(images)  # Forward pass
        # Compute losses
        loss_class = F.cross_entropy(class_logits, labels)  # Classification loss
        loss_bbox = F.mse_loss(bbox_preds, bboxes)  # Bounding box regression loss
        total_loss = loss_class + self.lambda_bbox * loss_bbox  # Combined loss
        # Logging (epoch-level only; on_step=False)
        self.log("train_loss", total_loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log("train_class_loss", loss_class, prog_bar=True, on_step=False, on_epoch=True)
        self.log("train_bbox_loss", loss_bbox, prog_bar=True, on_step=False, on_epoch=True)
        return total_loss

    def validation_step(self, batch, batch_idx):
        """Validation step: same loss as training, logged as val_* metrics
        (val_loss is monitored by the EarlyStopping callback below)."""
        images, bboxes, labels = batch
        class_logits, bbox_preds = self(images)
        loss_class = F.cross_entropy(class_logits, labels)
        loss_bbox = F.mse_loss(bbox_preds, bboxes)
        total_loss = loss_class + self.lambda_bbox * loss_bbox
        # Logging
        self.log("val_loss", total_loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log("val_class_loss", loss_class, prog_bar=True, on_step=False, on_epoch=True)
        self.log("val_bbox_loss", loss_bbox, prog_bar=True, on_step=False, on_epoch=True)
        return total_loss

    def configure_optimizers(self):
        """Adam over all parameters (backbone is fine-tuned, not frozen)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
In [11]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import CSVLogger

# Log metrics to CSV and stop early if val_loss stalls for 25 epochs.
logger = CSVLogger(save_dir='logs/', name='SingleDetector', version="")
early_stopping = EarlyStopping(monitor='val_loss', patience=25, verbose=True, mode="min")

# Instantiate the model with the bbox-loss weight.
model = PetClassifierAndBBox(lambda_bbox=5)

# Fit for up to 50 epochs on the loaders built above.
trainer = pl.Trainer(
    max_epochs=50,
    logger=logger,
    callbacks=[early_stopping],
)
trainer.fit(model, train_loader, val_loader)

# Persist the final (last-epoch) weights for the evaluation cells below.
trainer.save_checkpoint('logs/SingleDetector/final_model.ckpt')
GPU available: True (cuda), used: True TPU available: False, using: 0 TPU cores HPU available: False, using: 0 HPUs LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] | Name | Type | Params | Mode ----------------------------------------------------------- 0 | feature_extractor | Sequential | 4.0 M | train 1 | fc | Sequential | 655 K | train 2 | classification_head | Sequential | 263 K | train 3 | bbox_head | Sequential | 264 K | train ----------------------------------------------------------- 5.2 M Trainable params 0 Non-trainable params 5.2 M Total params 20.767 Total estimated model params size (MB)
Sanity Checking: | | 0/? [00:00<?, ?it/s]
Training: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved. New best score: 0.189
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 0.181
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.014 >= min_delta = 0.0. New best score: 0.168
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.027 >= min_delta = 0.0. New best score: 0.141
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.042 >= min_delta = 0.0. New best score: 0.099
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.090
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.089
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 0.089
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.010 >= min_delta = 0.0. New best score: 0.079
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.076
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 0.070
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.067
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.065
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
`Trainer.fit` stopped: `max_epochs=50` reached.
In [12]:
# Load the final (last-epoch) checkpoint saved by the training cell above
model = PetClassifierAndBBox.load_from_checkpoint('logs/SingleDetector/final_model.ckpt')
# Set the model to evaluation mode (disables dropout, uses running BN stats)
model.eval()
def visualize_predictions(dataloader, model, num_images=9):
    """
    Display a grid of images with true (green) and predicted (red) bounding
    boxes, plus the predicted class label and its probability.

    Parameters
    ----------
    dataloader : iterable yielding (images, bboxes, labels) batches with
        normalized (xmin, ymin, xmax, ymax) boxes; label 1 = cat, 0 = dog.
    model : module returning (class_logits, bbox_preds) for a batch of images.
        Assumed to already be in eval mode (set by the caller above).
    num_images : int
        Number of samples to show. The grid shape is derived from this value
        (3x3 for the default of 9), so values other than 9 no longer overflow
        the previously hard-coded 3x3 axes array.
    """
    # Derive a near-square grid large enough to hold num_images panels.
    n_cols = max(1, round(num_images ** 0.5))
    n_rows = -(-num_images // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
    # plt.subplots returns a bare Axes (not an array) for a 1x1 grid.
    axes = axes.flatten() if num_images > 1 else [axes]
    # Run inference on whatever device the model lives on; the dataloader
    # yields CPU tensors, so move inputs over and bring predictions back
    # for numpy/matplotlib.
    device = next(model.parameters()).device
    count = 0
    for images, true_bboxes, true_labels in dataloader:
        with torch.no_grad():
            class_logits, pred_bboxes = model(images.to(device))
        class_logits = class_logits.cpu()
        pred_bboxes = pred_bboxes.cpu()
        pred_probs = torch.softmax(class_logits, dim=1)
        pred_labels = torch.argmax(class_logits, dim=1)
        for i in range(len(images)):
            if count >= num_images:
                break
            sample_img = images[i]
            true_bbox = true_bboxes[i]  # Format: (xmin, ymin, xmax, ymax)
            pred_bbox = pred_bboxes[i]  # Format: (xmin, ymin, xmax, ymax)
            pred_label = pred_labels[i].item()
            pred_prob = pred_probs[i][pred_label].item()
            # Convert tensor image back to numpy for visualization
            img_np = sample_img.permute(1, 2, 0).numpy()
            # Denormalize true bbox coordinates (based on resized 224x224 image)
            true_xmin = true_bbox[0] * 224
            true_ymin = true_bbox[1] * 224
            true_xmax = true_bbox[2] * 224
            true_ymax = true_bbox[3] * 224
            true_width = true_xmax - true_xmin
            true_height = true_ymax - true_ymin
            # Denormalize predicted bbox coordinates the same way
            pred_xmin = pred_bbox[0] * 224
            pred_ymin = pred_bbox[1] * 224
            pred_xmax = pred_bbox[2] * 224
            pred_ymax = pred_bbox[3] * 224
            pred_width = pred_xmax - pred_xmin
            pred_height = pred_ymax - pred_ymin
            # Plot the image
            ax = axes[count]
            ax.imshow(img_np)
            ax.set_xticks([])
            ax.set_yticks([])
            # Draw true bounding box (green)
            true_rect = patches.Rectangle((true_xmin, true_ymin), true_width, true_height, linewidth=2, edgecolor="g", facecolor="none")
            ax.add_patch(true_rect)
            # Draw predicted bounding box (red)
            pred_rect = patches.Rectangle((pred_xmin, pred_ymin), pred_width, pred_height, linewidth=2, edgecolor="r", facecolor="none")
            ax.add_patch(pred_rect)
            # Add predicted label and probability
            label_text = f"{'Cat' if pred_label == 1 else 'Dog'}: {pred_prob:.2f}"
            ax.text(pred_xmin, pred_ymin - 10, label_text, color="white", fontsize=12, bbox=dict(facecolor="red", alpha=0.5, edgecolor="none"))
            count += 1
        if count >= num_images:
            break
    # Hide any unused panels (when num_images < n_rows * n_cols).
    for ax in axes[count:]:
        ax.axis("off")
    plt.tight_layout()
    plt.show()


# Call the function to visualize predictions
visualize_predictions(val_loader, model)